{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# How to create a custom checkpointer using MongoDB\n",
        "\n",
        "When creating LangGraph agents, you can also set them up so that they persist their state. This allows you to do things like interact with an agent multiple times and have it remember previous interactions.\n",
        "\n",
        "This example shows how to use `MongoDB` as the backend for persisting checkpoint state.\n",
        "\n",
        "NOTE: this is just an example implementation. You can implement your own checkpointer using a different database or modify this one as long as it conforms to the `BaseCheckpointSaver` interface."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Checkpointer implementation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "%%capture --no-stderr\n",
        "%pip install -U langgraph pymongo"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {},
      "outputs": [],
      "source": [
        "import pickle\n",
        "from contextlib import AbstractContextManager\n",
        "from types import TracebackType\n",
        "from typing import Any, Dict, Iterator, Optional\n",
        "\n",
        "from langchain_core.runnables import RunnableConfig\n",
        "from typing_extensions import Self\n",
        "\n",
        "from langgraph.checkpoint.base import (\n",
        "    BaseCheckpointSaver,\n",
        "    Checkpoint,\n",
        "    CheckpointMetadata,\n",
        "    CheckpointTuple,\n",
        "    SerializerProtocol,\n",
        ")\n",
        "from langgraph.serde.jsonplus import JsonPlusSerializer\n",
        "from pymongo import MongoClient\n",
        "\n",
        "\n",
        "class JsonPlusSerializerCompat(JsonPlusSerializer):\n",
        "    \"\"\"JsonPlusSerializer variant that can also read pickled checkpoints.\n",
        "\n",
        "    Kept for backwards compatibility: a payload that begins with the byte 0x80\n",
        "    and ends with the byte \".\" is treated as a pickle and decoded with\n",
        "    pickle.loads(). Every other payload is handed to JsonPlusSerializer.loads().\n",
        "\n",
        "    Examples:\n",
        "        >>> import pickle\n",
        "        >>>\n",
        "        >>> serializer = JsonPlusSerializerCompat()\n",
        "        >>> pickled_data = pickle.dumps({\"key\": \"value\"})\n",
        "        >>> print(serializer.loads(pickled_data))  # Output: {\"key\": \"value\"}\n",
        "        >>>\n",
        "        >>> json_data = '{\"key\": \"value\"}'.encode(\"utf-8\")\n",
        "        >>> print(serializer.loads(json_data))  # Output: {\"key\": \"value\"}\n",
        "    \"\"\"\n",
        "\n",
        "    def loads(self, data: bytes) -> Any:\n",
        "        # NOTE: pickle.loads() can execute arbitrary code, so this serializer\n",
        "        # should only be pointed at checkpoint data from a trusted store.\n",
        "        is_pickled = data.startswith(b\"\\x80\") and data.endswith(b\".\")\n",
        "        return pickle.loads(data) if is_pickled else super().loads(data)\n",
        "\n",
        "\n",
        "class MongoDBSaver(AbstractContextManager, BaseCheckpointSaver):\n",
        "    \"\"\"A checkpoint saver that stores checkpoints in a MongoDB database.\n",
        "\n",
        "    Every checkpoint is written as one document with the fields `thread_id`,\n",
        "    `thread_ts` (the checkpoint id), `checkpoint` and `metadata` (both serialized\n",
        "    with `serde`) and, when the checkpoint has a parent, `parent_ts`.\n",
        "\n",
        "    Args:\n",
        "        client (pymongo.MongoClient): The MongoDB client.\n",
        "        db_name (str): The name of the database to use.\n",
        "        collection_name (str): The name of the collection to use.\n",
        "        serde (Optional[SerializerProtocol]): The serializer to use for serializing and deserializing checkpoints. Defaults to JsonPlusSerializerCompat.\n",
        "\n",
        "    Examples:\n",
        "\n",
        "        >>> from pymongo import MongoClient\n",
        "        >>> from langgraph.graph import StateGraph\n",
        "        >>>\n",
        "        >>> builder = StateGraph(int)\n",
        "        >>> builder.add_node(\"add_one\", lambda x: x + 1)\n",
        "        >>> builder.set_entry_point(\"add_one\")\n",
        "        >>> builder.set_finish_point(\"add_one\")\n",
        "        >>> client = MongoClient(\"mongodb://localhost:27017/\")\n",
        "        >>> memory = MongoDBSaver(client, \"checkpoints\", \"checkpoints\")\n",
        "        >>> graph = builder.compile(checkpointer=memory)\n",
        "        >>> config = {\"configurable\": {\"thread_id\": \"1\"}}\n",
        "        >>> graph.invoke(3, config)\n",
        "        4\n",
        "        >>> graph.get_state(config).values\n",
        "        4\n",
        "    \"\"\"\n",
        "\n",
        "    serde = JsonPlusSerializerCompat()\n",
        "\n",
        "    client: MongoClient\n",
        "    db_name: str\n",
        "    collection_name: str\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        client: MongoClient,\n",
        "        db_name: str,\n",
        "        collection_name: str,\n",
        "        *,\n",
        "        serde: Optional[SerializerProtocol] = None,\n",
        "    ) -> None:\n",
        "        super().__init__(serde=serde)\n",
        "        self.client = client\n",
        "        self.db_name = db_name\n",
        "        self.collection_name = collection_name\n",
        "        self.collection = client[db_name][collection_name]\n",
        "\n",
        "    def __enter__(self) -> Self:\n",
        "        return self\n",
        "\n",
        "    def __exit__(\n",
        "        self,\n",
        "        __exc_type: Optional[type[BaseException]],\n",
        "        __exc_value: Optional[BaseException],\n",
        "        __traceback: Optional[TracebackType],\n",
        "    ) -> Optional[bool]:\n",
        "        # A truthy return value would suppress any exception raised inside the\n",
        "        # `with` block; return None so errors propagate to the caller.\n",
        "        return None\n",
        "\n",
        "    @staticmethod\n",
        "    def _parent_config(doc: Dict[str, Any]) -> Optional[RunnableConfig]:\n",
        "        \"\"\"Build the config pointing at a stored document's parent checkpoint, or None if it has no parent.\"\"\"\n",
        "        if not doc.get(\"parent_ts\"):\n",
        "            return None\n",
        "        return {\n",
        "            \"configurable\": {\n",
        "                \"thread_id\": doc[\"thread_id\"],\n",
        "                \"thread_ts\": doc[\"parent_ts\"],\n",
        "            }\n",
        "        }\n",
        "\n",
        "    def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:\n",
        "        \"\"\"Get a checkpoint tuple from the database.\n",
        "\n",
        "        This method retrieves a checkpoint tuple from the MongoDB database based on the\n",
        "        provided config. If the config contains a \"thread_ts\" key, the checkpoint with\n",
        "        the matching thread ID and timestamp is retrieved. Otherwise, the latest checkpoint\n",
        "        for the given thread ID is retrieved.\n",
        "\n",
        "        Args:\n",
        "            config (RunnableConfig): The config to use for retrieving the checkpoint.\n",
        "\n",
        "        Returns:\n",
        "            Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching\n",
        "            checkpoint was found. The tuple's config always carries the \"thread_ts\" of the\n",
        "            checkpoint that was actually loaded.\n",
        "        \"\"\"\n",
        "        configurable = config[\"configurable\"]\n",
        "        query = {\"thread_id\": configurable[\"thread_id\"]}\n",
        "        requested_ts = configurable.get(\"thread_ts\")\n",
        "        if requested_ts:\n",
        "            query[\"thread_ts\"] = requested_ts\n",
        "        # Newest checkpoint first; with an explicit thread_ts at most one document matches.\n",
        "        doc = self.collection.find_one(query, sort=[(\"thread_ts\", -1)])\n",
        "        if doc is None:\n",
        "            return None\n",
        "        if not requested_ts:\n",
        "            # The caller asked for \"the latest\" checkpoint: record which one that was,\n",
        "            # keeping every other key of the caller's config untouched.\n",
        "            config = {\n",
        "                **config,\n",
        "                \"configurable\": {**configurable, \"thread_ts\": doc[\"thread_ts\"]},\n",
        "            }\n",
        "        return CheckpointTuple(\n",
        "            config,\n",
        "            self.serde.loads(doc[\"checkpoint\"]),\n",
        "            self.serde.loads(doc[\"metadata\"]),\n",
        "            self._parent_config(doc),\n",
        "        )\n",
        "\n",
        "    def list(\n",
        "        self,\n",
        "        config: Optional[RunnableConfig],\n",
        "        *,\n",
        "        filter: Optional[Dict[str, Any]] = None,\n",
        "        before: Optional[RunnableConfig] = None,\n",
        "        limit: Optional[int] = None,\n",
        "    ) -> Iterator[CheckpointTuple]:\n",
        "        \"\"\"List checkpoints from the database.\n",
        "\n",
        "        This method retrieves a list of checkpoint tuples from the MongoDB database based\n",
        "        on the provided config. The checkpoints are ordered by timestamp in descending order.\n",
        "\n",
        "        Args:\n",
        "            config (Optional[RunnableConfig]): The config whose thread ID selects the checkpoints. If None, checkpoints of all threads are listed.\n",
        "            filter (Optional[Dict[str, Any]]): Metadata key/value pairs a checkpoint must match exactly. Defaults to None.\n",
        "            before (Optional[RunnableConfig]): If provided, only checkpoints before the specified timestamp are returned. Defaults to None.\n",
        "            limit (Optional[int]): The maximum number of checkpoints to return. None (or 0) means no limit. Defaults to None.\n",
        "\n",
        "        Yields:\n",
        "            Iterator[CheckpointTuple]: An iterator of checkpoint tuples.\n",
        "        \"\"\"\n",
        "        query = {}\n",
        "        if config is not None:\n",
        "            query[\"thread_id\"] = config[\"configurable\"][\"thread_id\"]\n",
        "        if before is not None:\n",
        "            query[\"thread_ts\"] = {\"$lt\": before[\"configurable\"][\"thread_ts\"]}\n",
        "        cursor = self.collection.find(query).sort(\"thread_ts\", -1)\n",
        "        # `limit=None` must not reach pymongo: Cursor.limit() only accepts integers.\n",
        "        # With a metadata filter the limit is enforced below, after filtering.\n",
        "        if limit and not filter:\n",
        "            cursor = cursor.limit(limit)\n",
        "        yielded = 0\n",
        "        for doc in cursor:\n",
        "            metadata = self.serde.loads(doc[\"metadata\"])\n",
        "            # Metadata is stored as serialized bytes, so it cannot be matched by a\n",
        "            # MongoDB query; compare the deserialized values here instead.\n",
        "            if filter and not all(\n",
        "                key in metadata and metadata[key] == value\n",
        "                for key, value in filter.items()\n",
        "            ):\n",
        "                continue\n",
        "            yield CheckpointTuple(\n",
        "                {\n",
        "                    \"configurable\": {\n",
        "                        \"thread_id\": doc[\"thread_id\"],\n",
        "                        \"thread_ts\": doc[\"thread_ts\"],\n",
        "                    }\n",
        "                },\n",
        "                self.serde.loads(doc[\"checkpoint\"]),\n",
        "                metadata,\n",
        "                self._parent_config(doc),\n",
        "            )\n",
        "            yielded += 1\n",
        "            if filter and limit and yielded >= limit:\n",
        "                break\n",
        "\n",
        "    def put(\n",
        "        self,\n",
        "        config: RunnableConfig,\n",
        "        checkpoint: Checkpoint,\n",
        "        metadata: CheckpointMetadata,\n",
        "    ) -> RunnableConfig:\n",
        "        \"\"\"Save a checkpoint to the database.\n",
        "\n",
        "        This method saves a checkpoint to the MongoDB database. The checkpoint is associated\n",
        "        with the provided config and its parent config (if any).\n",
        "\n",
        "        Args:\n",
        "            config (RunnableConfig): The config to associate with the checkpoint. Its \"thread_ts\", if set, is stored as the parent checkpoint.\n",
        "            checkpoint (Checkpoint): The checkpoint to save.\n",
        "            metadata (CheckpointMetadata): Metadata to save with the checkpoint.\n",
        "\n",
        "        Returns:\n",
        "            RunnableConfig: The updated config containing the saved checkpoint's timestamp.\n",
        "        \"\"\"\n",
        "        thread_id = config[\"configurable\"][\"thread_id\"]\n",
        "        doc = {\n",
        "            \"thread_id\": thread_id,\n",
        "            \"thread_ts\": checkpoint[\"id\"],\n",
        "            \"checkpoint\": self.serde.dumps(checkpoint),\n",
        "            \"metadata\": self.serde.dumps(metadata),\n",
        "        }\n",
        "        parent_ts = config[\"configurable\"].get(\"thread_ts\")\n",
        "        if parent_ts:\n",
        "            doc[\"parent_ts\"] = parent_ts\n",
        "        self.collection.insert_one(doc)\n",
        "        return {\n",
        "            \"configurable\": {\n",
        "                \"thread_id\": thread_id,\n",
        "                \"thread_ts\": checkpoint[\"id\"],\n",
        "            }\n",
        "        }"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## MongoDB connection"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [],
      "source": [
        "# URI of the MongoDB server used by the checkpointer examples (localhost, port 27017).\n",
        "MONGO_URI = \"mongodb://localhost:27017/\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Basic example using graph"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "StateSnapshot(values=4, next=(), config={'configurable': {'thread_id': '123'}}, metadata={'source': 'loop', 'step': 1, 'writes': {'add_one': 4}}, created_at='2024-07-09T15:56:06.885848+00:00', parent_config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3e0bc-09c1-6c26-8000-b9e1d26417ff'}})"
            ]
          },
          "execution_count": 3,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "from langgraph.graph import StateGraph, START, END\n",
        "\n",
        "# Checkpoints go to the \"checkpoints_collection\" collection of the \"checkpoints_db\" database.\n",
        "checkpointer = MongoDBSaver(\n",
        "    MongoClient(MONGO_URI), \"checkpoints_db\", \"checkpoints_collection\"\n",
        ")\n",
        "# A single-node graph over an int state: START -> add_one -> END.\n",
        "builder = StateGraph(int)\n",
        "builder.add_node(\"add_one\", lambda x: x + 1)\n",
        "builder.add_edge(START, \"add_one\")\n",
        "builder.add_edge(\"add_one\", END)\n",
        "graph = builder.compile(checkpointer=checkpointer)\n",
        "# The thread_id selects which thread's checkpoints are read and written.\n",
        "config = {\"configurable\": {\"thread_id\": \"123\"}}\n",
        "# Not the last expression of the cell, so this result is not displayed.\n",
        "graph.get_state(config)\n",
        "result = graph.invoke(3, config)\n",
        "# Last expression of the cell: the state after the run is shown as the cell output.\n",
        "graph.get_state(config)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "4"
            ]
          },
          "execution_count": 4,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "result"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'v': 1,\n",
              " 'ts': '2024-07-09T15:56:06.885848+00:00',\n",
              " 'id': '1ef3e0bc-09d1-6a75-8001-8f750e9a0782',\n",
              " 'channel_values': {'__root__': 4, 'add_one': 'add_one'},\n",
              " 'channel_versions': {'__start__': 2,\n",
              "  '__root__': 3,\n",
              "  'start:add_one': 3,\n",
              "  'add_one': 3},\n",
              " 'versions_seen': {'__start__': {'__start__': 1},\n",
              "  'add_one': {'start:add_one': 2}},\n",
              " 'pending_sends': []}"
            ]
          },
          "execution_count": 5,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "checkpointer.get(config)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "CheckpointTuple(config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3e0bc-09d1-6a75-8001-8f750e9a0782'}}, checkpoint={'v': 1, 'ts': '2024-07-09T15:56:06.885848+00:00', 'id': '1ef3e0bc-09d1-6a75-8001-8f750e9a0782', 'channel_values': {'__root__': 4, 'add_one': 'add_one'}, 'channel_versions': {'__start__': 2, '__root__': 3, 'start:add_one': 3, 'add_one': 3}, 'versions_seen': {'__start__': {'__start__': 1}, 'add_one': {'start:add_one': 2}}, 'pending_sends': []}, metadata={'source': 'loop', 'step': 1, 'writes': {'add_one': 4}}, parent_config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3e0bc-09c1-6c26-8000-b9e1d26417ff'}})\n",
            "CheckpointTuple(config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3e0bc-09c1-6c26-8000-b9e1d26417ff'}}, checkpoint={'v': 1, 'ts': '2024-07-09T15:56:06.878338+00:00', 'id': '1ef3e0bc-09c1-6c26-8000-b9e1d26417ff', 'channel_values': {'__root__': 3, 'start:add_one': '__start__'}, 'channel_versions': {'__start__': 2, '__root__': 2, 'start:add_one': 2}, 'versions_seen': {'__start__': {'__start__': 1}, 'add_one': {}}, 'pending_sends': []}, metadata={'source': 'loop', 'step': 0, 'writes': None}, parent_config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3e0bc-09bc-6e04-bfff-5342ac1ccc10'}})\n",
            "CheckpointTuple(config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3e0bc-09bc-6e04-bfff-5342ac1ccc10'}}, checkpoint={'v': 1, 'ts': '2024-07-09T15:56:06.877337+00:00', 'id': '1ef3e0bc-09bc-6e04-bfff-5342ac1ccc10', 'channel_values': {'__start__': 3}, 'channel_versions': {'__start__': 1}, 'versions_seen': {}, 'pending_sends': []}, metadata={'source': 'input', 'step': -1, 'writes': 3}, parent_config=None)\n"
          ]
        }
      ],
      "source": [
        "# Use a dedicated name: binding the result to `list` would shadow the builtin\n",
        "# for every cell executed afterwards in this kernel.\n",
        "checkpoint_tuples = checkpointer.list(config, limit=3)\n",
        "for checkpoint_tuple in checkpoint_tuples:\n",
        "    print(checkpoint_tuple)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "CheckpointTuple(config={'configurable': {'thread_id': '123'}}, checkpoint={'v': 1, 'ts': '2024-07-09T13:22:19.610402+00:00', 'id': '1ef3df64-4ba9-6b58-8001-ab084cc01a30', 'channel_values': {'__root__': 4, 'add_one': 'add_one'}, 'channel_versions': {'__start__': 2, '__root__': 3, 'start:add_one': 3, 'add_one': 3}, 'versions_seen': {'__start__': {'__start__': 1}, 'add_one': {'start:add_one': 2}}, 'pending_sends': []}, metadata={'source': 'loop', 'step': 1, 'writes': {'add_one': 4}}, parent_config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3df64-4ba2-660c-8000-569999697ff3'}})"
            ]
          },
          "execution_count": 13,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "checkpointer.get_tuple(config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Set up the environment"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {},
      "outputs": [],
      "source": [
        "import getpass\n",
        "import os\n",
        "\n",
        "\n",
        "def _set_env(var: str):\n",
        "    \"\"\"Ask for `var` with a hidden prompt unless the environment already holds a non-empty value.\"\"\"\n",
        "    if os.environ.get(var):\n",
        "        return\n",
        "    os.environ[var] = getpass.getpass(f\"{var}: \")\n",
        "\n",
        "\n",
        "_set_env(\"OPENAI_API_KEY\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Set up the model and tools for the graph"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "%pip install langchain_openai"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {},
      "outputs": [],
      "source": [
        "from typing import Literal\n",
        "from langchain_core.runnables import ConfigurableField\n",
        "from langchain_core.tools import tool\n",
        "from langchain_openai import ChatOpenAI\n",
        "from langgraph.prebuilt import create_react_agent\n",
        "\n",
        "\n",
        "@tool\n",
        "def get_weather(city: Literal[\"nyc\", \"sf\"]):\n",
        "    \"\"\"Use this to get weather information.\"\"\"\n",
        "    # Canned answers for the two supported cities; anything else is rejected.\n",
        "    if city == \"nyc\":\n",
        "        return \"It might be cloudy in nyc\"\n",
        "    if city == \"sf\":\n",
        "        return \"It's always sunny in sf\"\n",
        "    raise AssertionError(\"Unknown city\")\n",
        "\n",
        "\n",
        "tools = [get_weather]\n",
        "model = ChatOpenAI(model_name=\"gpt-3.5-turbo\", temperature=0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Reuse the MongoDB checkpointer created above so the agent's state is persisted too.\n",
        "graph = create_react_agent(model, tools=tools, checkpointer=checkpointer)\n",
        "# A different thread_id than the basic example (\"123\"), so the two runs are stored under separate threads.\n",
        "config = {\"configurable\": {\"thread_id\": \"1\"}}\n",
        "res = graph.invoke({\"messages\": [(\"human\", \"what's the weather in sf\")]}, config)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'messages': [HumanMessage(content=\"what's the weather in sf\", id='a624d383-13c6-499c-8f03-31ed11fa0cfb'),\n",
              "  AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_wapm4s91KQUQqE9y1L53QmmE', 'function': {'arguments': '{\"city\":\"sf\"}', 'name': 'get_weather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 14, 'prompt_tokens': 58, 'total_tokens': 72}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': None, 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-614bc54a-ad37-4f17-9047-b80752bdf66e-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'sf'}, 'id': 'call_wapm4s91KQUQqE9y1L53QmmE'}]),\n",
              "  ToolMessage(content=\"It's always sunny in sf\", name='get_weather', id='e58dd97d-b50d-4b0a-9492-7155106c975a', tool_call_id='call_wapm4s91KQUQqE9y1L53QmmE'),\n",
              "  AIMessage(content='The weather in San Francisco is always sunny!', response_metadata={'token_usage': {'completion_tokens': 10, 'prompt_tokens': 86, 'total_tokens': 96}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-e9984eca-f132-46d0-94ba-41d8ba5b7046-0')]}"
            ]
          },
          "execution_count": 17,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "res"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "CheckpointTuple(config={'configurable': {'thread_id': '1'}}, checkpoint={'v': 1, 'ts': '2024-07-09T13:22:49.794047+00:00', 'id': '1ef3df65-6b84-63fd-8003-888bcef289e3', 'channel_values': {'messages': [HumanMessage(content=\"what's the weather in sf\", id='a624d383-13c6-499c-8f03-31ed11fa0cfb'), AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_wapm4s91KQUQqE9y1L53QmmE', 'function': {'arguments': '{\"city\":\"sf\"}', 'name': 'get_weather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 14, 'prompt_tokens': 58, 'total_tokens': 72}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': None, 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-614bc54a-ad37-4f17-9047-b80752bdf66e-0', tool_calls=[{'name': 'get_weather', 'args': {'city': 'sf'}, 'id': 'call_wapm4s91KQUQqE9y1L53QmmE'}]), ToolMessage(content=\"It's always sunny in sf\", name='get_weather', id='e58dd97d-b50d-4b0a-9492-7155106c975a', tool_call_id='call_wapm4s91KQUQqE9y1L53QmmE'), AIMessage(content='The weather in San Francisco is always sunny!', response_metadata={'token_usage': {'completion_tokens': 10, 'prompt_tokens': 86, 'total_tokens': 96}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-e9984eca-f132-46d0-94ba-41d8ba5b7046-0')], 'agent': 'agent'}, 'channel_versions': {'__start__': 2, 'messages': 5, 'start:agent': 3, 'agent': 5, 'branch:agent:should_continue:tools': 4, 'tools': 5}, 'versions_seen': {'__start__': {'__start__': 1}, 'agent': {'start:agent': 3, 'tools': 4}, 'tools': {'branch:agent:should_continue:tools': 3}}, 'pending_sends': []}, metadata={'source': 'loop', 'step': 3, 'writes': {'agent': {'messages': [AIMessage(content='The weather in San Francisco is always sunny!', response_metadata={'token_usage': {'completion_tokens': 10, 'prompt_tokens': 86, 'total_tokens': 96}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': None, 'finish_reason': 'stop', 
'logprobs': None}, id='run-e9984eca-f132-46d0-94ba-41d8ba5b7046-0')]}}}, parent_config={'configurable': {'thread_id': '1', 'thread_ts': '1ef3df65-6252-6b30-8002-2c9e1e68364b'}})"
            ]
          },
          "execution_count": 18,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "checkpointer.get_tuple(config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Checkpoints saved in MongoDB"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "{'_id': ObjectId('668d398bb975d3e766de42ce'), 'thread_id': '123', 'thread_ts': '1ef3df64-4b98-68b4-bfff-592f97570cf6', 'checkpoint': b'{\"v\": 1, \"ts\": \"2024-07-09T13:22:19.603371+00:00\", \"id\": \"1ef3df64-4b98-68b4-bfff-592f97570cf6\", \"channel_values\": {\"__start__\": 3}, \"channel_versions\": {\"__start__\": 1}, \"versions_seen\": {}, \"pending_sends\": []}', 'metadata': b'{\"source\": \"input\", \"step\": -1, \"writes\": 3}'}\n",
            "{'_id': ObjectId('668d398bb975d3e766de42cf'), 'thread_id': '123', 'thread_ts': '1ef3df64-4ba2-660c-8000-569999697ff3', 'checkpoint': b'{\"v\": 1, \"ts\": \"2024-07-09T13:22:19.607399+00:00\", \"id\": \"1ef3df64-4ba2-660c-8000-569999697ff3\", \"channel_values\": {\"__root__\": 3, \"start:add_one\": \"__start__\"}, \"channel_versions\": {\"__start__\": 2, \"__root__\": 2, \"start:add_one\": 2}, \"versions_seen\": {\"__start__\": {\"__start__\": 1}, \"add_one\": {}}, \"pending_sends\": []}', 'metadata': b'{\"source\": \"loop\", \"step\": 0, \"writes\": null}', 'parent_ts': '1ef3df64-4b98-68b4-bfff-592f97570cf6'}\n",
            "{'_id': ObjectId('668d398bb975d3e766de42d0'), 'thread_id': '123', 'thread_ts': '1ef3df64-4ba9-6b58-8001-ab084cc01a30', 'checkpoint': b'{\"v\": 1, \"ts\": \"2024-07-09T13:22:19.610402+00:00\", \"id\": \"1ef3df64-4ba9-6b58-8001-ab084cc01a30\", \"channel_values\": {\"__root__\": 4, \"add_one\": \"add_one\"}, \"channel_versions\": {\"__start__\": 2, \"__root__\": 3, \"start:add_one\": 3, \"add_one\": 3}, \"versions_seen\": {\"__start__\": {\"__start__\": 1}, \"add_one\": {\"start:add_one\": 2}}, \"pending_sends\": []}', 'metadata': b'{\"source\": \"loop\", \"step\": 1, \"writes\": {\"add_one\": 4}}', 'parent_ts': '1ef3df64-4ba2-660c-8000-569999697ff3'}\n",
            "{'_id': ObjectId('668d39a7b975d3e766de42d1'), 'thread_id': '1', 'thread_ts': '1ef3df65-585c-6abb-bfff-634fac15b6b3', 'checkpoint': b'{\"v\": 1, \"ts\": \"2024-07-09T13:22:47.785541+00:00\", \"id\": \"1ef3df65-585c-6abb-bfff-634fac15b6b3\", \"channel_values\": {\"messages\": [], \"__start__\": {\"messages\": [[\"human\", \"what\\'s the weather in sf\"]]}}, \"channel_versions\": {\"__start__\": 1}, \"versions_seen\": {}, \"pending_sends\": []}', 'metadata': b'{\"source\": \"input\", \"step\": -1, \"writes\": {\"messages\": [[\"human\", \"what\\'s the weather in sf\"]]}}'}\n",
            "{'_id': ObjectId('668d39a7b975d3e766de42d2'), 'thread_id': '1', 'thread_ts': '1ef3df65-5863-6fbc-8000-d45480f1ccc0', 'checkpoint': b'{\"v\": 1, \"ts\": \"2024-07-09T13:22:47.788537+00:00\", \"id\": \"1ef3df65-5863-6fbc-8000-d45480f1ccc0\", \"channel_values\": {\"messages\": [{\"lc\": 1, \"type\": \"constructor\", \"id\": [\"langchain\", \"schema\", \"messages\", \"HumanMessage\"], \"kwargs\": {\"content\": \"what\\'s the weather in sf\", \"type\": \"human\", \"id\": \"a624d383-13c6-499c-8f03-31ed11fa0cfb\"}}], \"start:agent\": \"__start__\"}, \"channel_versions\": {\"__start__\": 2, \"messages\": 2, \"start:agent\": 2}, \"versions_seen\": {\"__start__\": {\"__start__\": 1}, \"agent\": {}, \"tools\": {}}, \"pending_sends\": []}', 'metadata': b'{\"source\": \"loop\", \"step\": 0, \"writes\": null}', 'parent_ts': '1ef3df65-585c-6abb-bfff-634fac15b6b3'}\n",
            "{'_id': ObjectId('668d39a8b975d3e766de42d3'), 'thread_id': '1', 'thread_ts': '1ef3df65-6248-6ee8-8001-d5eef3fad087', 'checkpoint': b'{\"v\": 1, \"ts\": \"2024-07-09T13:22:48.826032+00:00\", \"id\": \"1ef3df65-6248-6ee8-8001-d5eef3fad087\", \"channel_values\": {\"messages\": [{\"lc\": 1, \"type\": \"constructor\", \"id\": [\"langchain\", \"schema\", \"messages\", \"HumanMessage\"], \"kwargs\": {\"content\": \"what\\'s the weather in sf\", \"type\": \"human\", \"id\": \"a624d383-13c6-499c-8f03-31ed11fa0cfb\"}}, {\"lc\": 1, \"type\": \"constructor\", \"id\": [\"langchain\", \"schema\", \"messages\", \"AIMessage\"], \"kwargs\": {\"content\": \"\", \"additional_kwargs\": {\"tool_calls\": [{\"id\": \"call_wapm4s91KQUQqE9y1L53QmmE\", \"function\": {\"arguments\": \"{\\\\\"city\\\\\":\\\\\"sf\\\\\"}\", \"name\": \"get_weather\"}, \"type\": \"function\"}]}, \"response_metadata\": {\"token_usage\": {\"completion_tokens\": 14, \"prompt_tokens\": 58, \"total_tokens\": 72}, \"model_name\": \"gpt-3.5-turbo\", \"system_fingerprint\": null, \"finish_reason\": \"tool_calls\", \"logprobs\": null}, \"type\": \"ai\", \"id\": \"run-614bc54a-ad37-4f17-9047-b80752bdf66e-0\", \"tool_calls\": [{\"name\": \"get_weather\", \"args\": {\"city\": \"sf\"}, \"id\": \"call_wapm4s91KQUQqE9y1L53QmmE\"}], \"invalid_tool_calls\": []}}], \"agent\": \"agent\", \"branch:agent:should_continue:tools\": \"agent\"}, \"channel_versions\": {\"__start__\": 2, \"messages\": 3, \"start:agent\": 3, \"agent\": 3, \"branch:agent:should_continue:tools\": 3}, \"versions_seen\": {\"__start__\": {\"__start__\": 1}, \"agent\": {\"start:agent\": 2}, \"tools\": {}}, \"pending_sends\": []}', 'metadata': b'{\"source\": \"loop\", \"step\": 1, \"writes\": {\"agent\": {\"messages\": [{\"lc\": 1, \"type\": \"constructor\", \"id\": [\"langchain\", \"schema\", \"messages\", \"AIMessage\"], \"kwargs\": {\"content\": \"\", \"additional_kwargs\": {\"tool_calls\": [{\"id\": \"call_wapm4s91KQUQqE9y1L53QmmE\", \"function\": 
{\"arguments\": \"{\\\\\"city\\\\\":\\\\\"sf\\\\\"}\", \"name\": \"get_weather\"}, \"type\": \"function\"}]}, \"response_metadata\": {\"token_usage\": {\"completion_tokens\": 14, \"prompt_tokens\": 58, \"total_tokens\": 72}, \"model_name\": \"gpt-3.5-turbo\", \"system_fingerprint\": null, \"finish_reason\": \"tool_calls\", \"logprobs\": null}, \"type\": \"ai\", \"id\": \"run-614bc54a-ad37-4f17-9047-b80752bdf66e-0\", \"tool_calls\": [{\"name\": \"get_weather\", \"args\": {\"city\": \"sf\"}, \"id\": \"call_wapm4s91KQUQqE9y1L53QmmE\"}], \"invalid_tool_calls\": []}}]}}}', 'parent_ts': '1ef3df65-5863-6fbc-8000-d45480f1ccc0'}\n",
            "{'_id': ObjectId('668d39a8b975d3e766de42d4'), 'thread_id': '1', 'thread_ts': '1ef3df65-6252-6b30-8002-2c9e1e68364b', 'checkpoint': b'{\"v\": 1, \"ts\": \"2024-07-09T13:22:48.830033+00:00\", \"id\": \"1ef3df65-6252-6b30-8002-2c9e1e68364b\", \"channel_values\": {\"messages\": [{\"lc\": 1, \"type\": \"constructor\", \"id\": [\"langchain\", \"schema\", \"messages\", \"HumanMessage\"], \"kwargs\": {\"content\": \"what\\'s the weather in sf\", \"type\": \"human\", \"id\": \"a624d383-13c6-499c-8f03-31ed11fa0cfb\"}}, {\"lc\": 1, \"type\": \"constructor\", \"id\": [\"langchain\", \"schema\", \"messages\", \"AIMessage\"], \"kwargs\": {\"content\": \"\", \"additional_kwargs\": {\"tool_calls\": [{\"id\": \"call_wapm4s91KQUQqE9y1L53QmmE\", \"function\": {\"arguments\": \"{\\\\\"city\\\\\":\\\\\"sf\\\\\"}\", \"name\": \"get_weather\"}, \"type\": \"function\"}]}, \"response_metadata\": {\"token_usage\": {\"completion_tokens\": 14, \"prompt_tokens\": 58, \"total_tokens\": 72}, \"model_name\": \"gpt-3.5-turbo\", \"system_fingerprint\": null, \"finish_reason\": \"tool_calls\", \"logprobs\": null}, \"type\": \"ai\", \"id\": \"run-614bc54a-ad37-4f17-9047-b80752bdf66e-0\", \"tool_calls\": [{\"name\": \"get_weather\", \"args\": {\"city\": \"sf\"}, \"id\": \"call_wapm4s91KQUQqE9y1L53QmmE\"}], \"invalid_tool_calls\": []}}, {\"lc\": 1, \"type\": \"constructor\", \"id\": [\"langchain\", \"schema\", \"messages\", \"ToolMessage\"], \"kwargs\": {\"content\": \"It\\'s always sunny in sf\", \"type\": \"tool\", \"name\": \"get_weather\", \"id\": \"e58dd97d-b50d-4b0a-9492-7155106c975a\", \"tool_call_id\": \"call_wapm4s91KQUQqE9y1L53QmmE\"}}], \"tools\": \"tools\"}, \"channel_versions\": {\"__start__\": 2, \"messages\": 4, \"start:agent\": 3, \"agent\": 4, \"branch:agent:should_continue:tools\": 4, \"tools\": 4}, \"versions_seen\": {\"__start__\": {\"__start__\": 1}, \"agent\": {\"start:agent\": 2}, \"tools\": {\"branch:agent:should_continue:tools\": 3}}, \"pending_sends\": []}', 
'metadata': b'{\"source\": \"loop\", \"step\": 2, \"writes\": {\"tools\": {\"messages\": [{\"lc\": 1, \"type\": \"constructor\", \"id\": [\"langchain\", \"schema\", \"messages\", \"ToolMessage\"], \"kwargs\": {\"content\": \"It\\'s always sunny in sf\", \"type\": \"tool\", \"name\": \"get_weather\", \"id\": \"e58dd97d-b50d-4b0a-9492-7155106c975a\", \"tool_call_id\": \"call_wapm4s91KQUQqE9y1L53QmmE\"}}]}}}', 'parent_ts': '1ef3df65-6248-6ee8-8001-d5eef3fad087'}\n",
            "{'_id': ObjectId('668d39a9b975d3e766de42d5'), 'thread_id': '1', 'thread_ts': '1ef3df65-6b84-63fd-8003-888bcef289e3', 'checkpoint': b'{\"v\": 1, \"ts\": \"2024-07-09T13:22:49.794047+00:00\", \"id\": \"1ef3df65-6b84-63fd-8003-888bcef289e3\", \"channel_values\": {\"messages\": [{\"lc\": 1, \"type\": \"constructor\", \"id\": [\"langchain\", \"schema\", \"messages\", \"HumanMessage\"], \"kwargs\": {\"content\": \"what\\'s the weather in sf\", \"type\": \"human\", \"id\": \"a624d383-13c6-499c-8f03-31ed11fa0cfb\"}}, {\"lc\": 1, \"type\": \"constructor\", \"id\": [\"langchain\", \"schema\", \"messages\", \"AIMessage\"], \"kwargs\": {\"content\": \"\", \"additional_kwargs\": {\"tool_calls\": [{\"id\": \"call_wapm4s91KQUQqE9y1L53QmmE\", \"function\": {\"arguments\": \"{\\\\\"city\\\\\":\\\\\"sf\\\\\"}\", \"name\": \"get_weather\"}, \"type\": \"function\"}]}, \"response_metadata\": {\"token_usage\": {\"completion_tokens\": 14, \"prompt_tokens\": 58, \"total_tokens\": 72}, \"model_name\": \"gpt-3.5-turbo\", \"system_fingerprint\": null, \"finish_reason\": \"tool_calls\", \"logprobs\": null}, \"type\": \"ai\", \"id\": \"run-614bc54a-ad37-4f17-9047-b80752bdf66e-0\", \"tool_calls\": [{\"name\": \"get_weather\", \"args\": {\"city\": \"sf\"}, \"id\": \"call_wapm4s91KQUQqE9y1L53QmmE\"}], \"invalid_tool_calls\": []}}, {\"lc\": 1, \"type\": \"constructor\", \"id\": [\"langchain\", \"schema\", \"messages\", \"ToolMessage\"], \"kwargs\": {\"content\": \"It\\'s always sunny in sf\", \"type\": \"tool\", \"name\": \"get_weather\", \"id\": \"e58dd97d-b50d-4b0a-9492-7155106c975a\", \"tool_call_id\": \"call_wapm4s91KQUQqE9y1L53QmmE\"}}, {\"lc\": 1, \"type\": \"constructor\", \"id\": [\"langchain\", \"schema\", \"messages\", \"AIMessage\"], \"kwargs\": {\"content\": \"The weather in San Francisco is always sunny!\", \"response_metadata\": {\"token_usage\": {\"completion_tokens\": 10, \"prompt_tokens\": 86, \"total_tokens\": 96}, \"model_name\": \"gpt-3.5-turbo\", 
\"system_fingerprint\": null, \"finish_reason\": \"stop\", \"logprobs\": null}, \"type\": \"ai\", \"id\": \"run-e9984eca-f132-46d0-94ba-41d8ba5b7046-0\", \"tool_calls\": [], \"invalid_tool_calls\": []}}], \"agent\": \"agent\"}, \"channel_versions\": {\"__start__\": 2, \"messages\": 5, \"start:agent\": 3, \"agent\": 5, \"branch:agent:should_continue:tools\": 4, \"tools\": 5}, \"versions_seen\": {\"__start__\": {\"__start__\": 1}, \"agent\": {\"start:agent\": 3, \"tools\": 4}, \"tools\": {\"branch:agent:should_continue:tools\": 3}}, \"pending_sends\": []}', 'metadata': b'{\"source\": \"loop\", \"step\": 3, \"writes\": {\"agent\": {\"messages\": [{\"lc\": 1, \"type\": \"constructor\", \"id\": [\"langchain\", \"schema\", \"messages\", \"AIMessage\"], \"kwargs\": {\"content\": \"The weather in San Francisco is always sunny!\", \"response_metadata\": {\"token_usage\": {\"completion_tokens\": 10, \"prompt_tokens\": 86, \"total_tokens\": 96}, \"model_name\": \"gpt-3.5-turbo\", \"system_fingerprint\": null, \"finish_reason\": \"stop\", \"logprobs\": null}, \"type\": \"ai\", \"id\": \"run-e9984eca-f132-46d0-94ba-41d8ba5b7046-0\", \"tool_calls\": [], \"invalid_tool_calls\": []}}]}}}', 'parent_ts': '1ef3df65-6252-6b30-8002-2c9e1e68364b'}\n"
          ]
        }
      ],
      "source": [
        "# The checkpoints from both of the examples above were saved in the database;\n",
        "# connect to it and print every stored checkpoint document.\n",
        "client = MongoClient(MONGO_URI)\n",
        "database = client[\"checkpoints_db\"]\n",
        "collection = database[\"checkpoints_collection\"]\n",
        "\n",
        "for doc in collection.find():\n",
        "    print(doc)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Asynchronous implementation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Async package for MongoDB\n",
        "%pip install motor"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {},
      "outputs": [],
      "source": [
        "import pickle\n",
        "from contextlib import AbstractContextManager\n",
        "from types import TracebackType\n",
        "from typing import Any, Dict, Optional, AsyncIterator\n",
        "\n",
        "from langchain_core.runnables import RunnableConfig\n",
        "from typing_extensions import Self\n",
        "\n",
        "from langgraph.checkpoint.base import (\n",
        "    BaseCheckpointSaver,\n",
        "    Checkpoint,\n",
        "    CheckpointMetadata,\n",
        "    CheckpointTuple,\n",
        "    SerializerProtocol,\n",
        ")\n",
        "from langgraph.serde.jsonplus import JsonPlusSerializer\n",
        "from motor.motor_asyncio import AsyncIOMotorClient\n",
        "\n",
        "\n",
        "class JsonPlusSerializerCompat(JsonPlusSerializer):\n",
        "    \"\"\"JsonPlusSerializer variant that can also read pickled checkpoints.\n",
        "\n",
        "    Input that starts with the byte 0x80 and ends with b\".\" is treated as a pickled\n",
        "    checkpoint and decoded with pickle.loads(). All other input is passed on to\n",
        "    JsonPlusSerializer.loads() unchanged. This exists for backwards compatibility\n",
        "    with checkpoints that were written with pickle.\n",
        "\n",
        "    Examples:\n",
        "        >>> import pickle\n",
        "        >>> from langgraph.checkpoint.sqlite import JsonPlusSerializerCompat\n",
        "        >>>\n",
        "        >>> serializer = JsonPlusSerializerCompat()\n",
        "        >>> pickled_data = pickle.dumps({\"key\": \"value\"})\n",
        "        >>> loaded_data = serializer.loads(pickled_data)\n",
        "        >>> print(loaded_data)  # Output: {\"key\": \"value\"}\n",
        "        >>>\n",
        "        >>> json_data = '{\"key\": \"value\"}'.encode(\"utf-8\")\n",
        "        >>> loaded_data = serializer.loads(json_data)\n",
        "        >>> print(loaded_data)  # Output: {\"key\": \"value\"}\n",
        "    \"\"\"\n",
        "\n",
        "    def loads(self, data: bytes) -> Any:\n",
        "        \"\"\"Deserialize ``data``, using pickle only for pickled-looking payloads.\"\"\"\n",
        "        looks_pickled = data.startswith(b\"\\x80\") and data.endswith(b\".\")\n",
        "        if not looks_pickled:\n",
        "            return super().loads(data)\n",
        "        # NOTE: pickle.loads() can execute arbitrary code, so this path should only\n",
        "        # ever be fed checkpoint data that comes from a trusted store.\n",
        "        return pickle.loads(data)\n",
        "\n",
        "\n",
        "class MongoDBSaver(AbstractContextManager, BaseCheckpointSaver):\n",
        "    \"\"\"An asynchronous checkpoint saver that stores checkpoints in a MongoDB database.\n",
        "\n",
        "    Each checkpoint is stored as one document with the keys \"thread_id\",\n",
        "    \"thread_ts\" (the checkpoint ID), \"checkpoint\" and \"metadata\" (both serialized\n",
        "    with `serde`) and, when the checkpoint has a parent, \"parent_ts\".\n",
        "\n",
        "    Args:\n",
        "        client (AsyncIOMotorClient): The Async MongoDB client.\n",
        "        db_name (str): The name of the database to use.\n",
        "        collection_name (str): The name of the collection to use.\n",
        "        serde (Optional[SerializerProtocol]): The serializer to use for serializing and deserializing checkpoints. Defaults to JsonPlusSerializerCompat.\n",
        "\n",
        "    Examples:\n",
        "\n",
        "        >>> from motor.motor_asyncio import AsyncIOMotorClient\n",
        "        >>> from langgraph.checkpoint.mongodb import MongoDBSaver\n",
        "        >>> from langgraph.graph import StateGraph\n",
        "        >>>\n",
        "        >>> builder = StateGraph(int)\n",
        "        >>> builder.add_node(\"add_one\", lambda x: x + 1)\n",
        "        >>> builder.set_entry_point(\"add_one\")\n",
        "        >>> builder.set_finish_point(\"add_one\")\n",
        "        >>> client = AsyncIOMotorClient(\"mongodb://localhost:27017/\")\n",
        "        >>> memory = MongoDBSaver(client, \"checkpoints\", \"checkpoints\")\n",
        "        >>> graph = builder.compile(checkpointer=memory)\n",
        "        >>> config = {\"configurable\": {\"thread_id\": \"1\"}}\n",
        "        >>> result = await graph.ainvoke(3, config)\n",
        "    \"\"\"\n",
        "\n",
        "    serde = JsonPlusSerializerCompat()\n",
        "\n",
        "    client: AsyncIOMotorClient\n",
        "    db_name: str\n",
        "    collection_name: str\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        client: AsyncIOMotorClient,\n",
        "        db_name: str,\n",
        "        collection_name: str,\n",
        "        *,\n",
        "        serde: Optional[SerializerProtocol] = None,\n",
        "    ) -> None:\n",
        "        super().__init__(serde=serde)\n",
        "        self.client = client\n",
        "        self.db_name = db_name\n",
        "        self.collection_name = collection_name\n",
        "        self.collection = client[db_name][collection_name]\n",
        "\n",
        "    def __enter__(self) -> Self:\n",
        "        return self\n",
        "\n",
        "    def __exit__(\n",
        "        self,\n",
        "        __exc_type: Optional[type[BaseException]],\n",
        "        __exc_value: Optional[BaseException],\n",
        "        __traceback: Optional[TracebackType],\n",
        "    ) -> Optional[bool]:\n",
        "        # Return None (falsy) so that an exception raised inside the `with` block\n",
        "        # propagates to the caller; returning True would silently suppress it.\n",
        "        return None\n",
        "\n",
        "    @staticmethod\n",
        "    def _parent_config(doc: Dict[str, Any]) -> Optional[RunnableConfig]:\n",
        "        \"\"\"Build the config that points at a stored document's parent checkpoint, if any.\"\"\"\n",
        "        if not doc.get(\"parent_ts\"):\n",
        "            return None\n",
        "        return {\n",
        "            \"configurable\": {\n",
        "                \"thread_id\": doc[\"thread_id\"],\n",
        "                \"thread_ts\": doc[\"parent_ts\"],\n",
        "            }\n",
        "        }\n",
        "\n",
        "    async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:\n",
        "        \"\"\"Get a checkpoint tuple from the database.\n",
        "\n",
        "        This method retrieves a checkpoint tuple from the MongoDB database based on the\n",
        "        provided config. If the config contains a \"thread_ts\" key, the checkpoint with\n",
        "        the matching thread ID and timestamp is retrieved. Otherwise, the checkpoint with\n",
        "        the highest \"thread_ts\" for the given thread ID is retrieved.\n",
        "\n",
        "        The config of the returned tuple always carries the \"thread_ts\" of the checkpoint\n",
        "        that was loaded, so it identifies that checkpoint even when the lookup was made\n",
        "        by thread ID only.\n",
        "\n",
        "        Args:\n",
        "            config (RunnableConfig): The config to use for retrieving the checkpoint.\n",
        "\n",
        "        Returns:\n",
        "            Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found.\n",
        "        \"\"\"\n",
        "        configurable = config[\"configurable\"]\n",
        "        query = {\"thread_id\": configurable[\"thread_id\"]}\n",
        "        if configurable.get(\"thread_ts\"):\n",
        "            query[\"thread_ts\"] = configurable[\"thread_ts\"]\n",
        "        # Sorted descending, so the first document has the highest thread_ts.\n",
        "        result = self.collection.find(query).sort(\"thread_ts\", -1).limit(1)\n",
        "        async for doc in result:\n",
        "            return CheckpointTuple(\n",
        "                # Keep everything the caller passed in and pin the checkpoint that was found.\n",
        "                {\n",
        "                    **config,\n",
        "                    \"configurable\": {**configurable, \"thread_ts\": doc[\"thread_ts\"]},\n",
        "                },\n",
        "                self.serde.loads(doc[\"checkpoint\"]),\n",
        "                self.serde.loads(doc[\"metadata\"]),\n",
        "                self._parent_config(doc),\n",
        "            )\n",
        "        return None\n",
        "\n",
        "    async def alist(\n",
        "        self,\n",
        "        config: Optional[RunnableConfig],\n",
        "        *,\n",
        "        filter: Optional[Dict[str, Any]] = None,\n",
        "        before: Optional[RunnableConfig] = None,\n",
        "        limit: Optional[int] = None,\n",
        "    ) -> AsyncIterator[CheckpointTuple]:\n",
        "        \"\"\"List checkpoints from the database.\n",
        "\n",
        "        This method retrieves a list of checkpoint tuples from the MongoDB database based\n",
        "        on the provided config. The checkpoints are ordered by \"thread_ts\" in descending order.\n",
        "\n",
        "        Args:\n",
        "            config (Optional[RunnableConfig]): The config to use for listing the checkpoints. If None, checkpoints are not restricted to one thread.\n",
        "            filter (Optional[Dict[str, Any]]): Metadata key/value pairs added to the query as \"metadata.<key>\" conditions. Defaults to None.\n",
        "            before (Optional[RunnableConfig]): If provided, only checkpoints before the specified timestamp are returned. Defaults to None.\n",
        "            limit (Optional[int]): The maximum number of checkpoints to return. Defaults to None.\n",
        "\n",
        "        Yields:\n",
        "            AsyncIterator[CheckpointTuple]: An Async iterator of checkpoint tuples.\n",
        "        \"\"\"\n",
        "        query: Dict[str, Any] = {}\n",
        "        if config is not None:\n",
        "            query[\"thread_id\"] = config[\"configurable\"][\"thread_id\"]\n",
        "        if filter:\n",
        "            # NOTE(review): aput() stores \"metadata\" as serialized bytes, so these\n",
        "            # dotted-path conditions presumably never match -- confirm, and filter on\n",
        "            # the deserialized metadata instead if metadata filtering is needed.\n",
        "            for key, value in filter.items():\n",
        "                query[f\"metadata.{key}\"] = value\n",
        "        if before is not None:\n",
        "            query[\"thread_ts\"] = {\"$lt\": before[\"configurable\"][\"thread_ts\"]}\n",
        "        result = self.collection.find(query).sort(\"thread_ts\", -1)\n",
        "        # Only apply a limit when one was given: the cursor rejects limit(None).\n",
        "        if limit is not None:\n",
        "            result = result.limit(limit)\n",
        "        async for doc in result:\n",
        "            yield CheckpointTuple(\n",
        "                {\n",
        "                    \"configurable\": {\n",
        "                        \"thread_id\": doc[\"thread_id\"],\n",
        "                        \"thread_ts\": doc[\"thread_ts\"],\n",
        "                    }\n",
        "                },\n",
        "                self.serde.loads(doc[\"checkpoint\"]),\n",
        "                self.serde.loads(doc[\"metadata\"]),\n",
        "                self._parent_config(doc),\n",
        "            )\n",
        "\n",
        "    async def aput(\n",
        "        self,\n",
        "        config: RunnableConfig,\n",
        "        checkpoint: Checkpoint,\n",
        "        metadata: CheckpointMetadata,\n",
        "    ) -> RunnableConfig:\n",
        "        \"\"\"Save a checkpoint to the database.\n",
        "\n",
        "        This method saves a checkpoint to the MongoDB database. The checkpoint is associated\n",
        "        with the provided config and its parent config (if any): when the config carries a\n",
        "        \"thread_ts\", that value is stored as the new document's \"parent_ts\".\n",
        "\n",
        "        Args:\n",
        "            config (RunnableConfig): The config to associate with the checkpoint.\n",
        "            checkpoint (Checkpoint): The checkpoint to save.\n",
        "            metadata (CheckpointMetadata): The metadata to save with the checkpoint.\n",
        "\n",
        "        Returns:\n",
        "            RunnableConfig: The updated config containing the saved checkpoint's timestamp.\n",
        "        \"\"\"\n",
        "        configurable = config[\"configurable\"]\n",
        "        doc = {\n",
        "            \"thread_id\": configurable[\"thread_id\"],\n",
        "            \"thread_ts\": checkpoint[\"id\"],\n",
        "            \"checkpoint\": self.serde.dumps(checkpoint),\n",
        "            \"metadata\": self.serde.dumps(metadata),\n",
        "        }\n",
        "        if configurable.get(\"thread_ts\"):\n",
        "            doc[\"parent_ts\"] = configurable[\"thread_ts\"]\n",
        "        await self.collection.insert_one(doc)\n",
        "        return {\n",
        "            \"configurable\": {\n",
        "                \"thread_id\": configurable[\"thread_id\"],\n",
        "                \"thread_ts\": checkpoint[\"id\"],\n",
        "            }\n",
        "        }"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Example with basic graph"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {},
      "outputs": [],
      "source": [
        "from langgraph.graph import END, START, StateGraph\n",
        "\n",
        "# Build a one-node graph that adds 1 to its integer state and saves its\n",
        "# checkpoints through the async MongoDB saver.\n",
        "checkpointer = MongoDBSaver(\n",
        "    AsyncIOMotorClient(MONGO_URI), \"checkpoints_db\", \"checkpoints_collection\"\n",
        ")\n",
        "builder = StateGraph(int)\n",
        "builder.add_node(\"add_one\", lambda x: x + 1)\n",
        "builder.add_edge(START, \"add_one\")\n",
        "builder.add_edge(\"add_one\", END)\n",
        "graph = builder.compile(checkpointer=checkpointer)\n",
        "config = {\"configurable\": {\"thread_id\": \"123\"}}\n",
        "res = await graph.ainvoke(3, config)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "4"
            ]
          },
          "execution_count": 6,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Value returned by the graph run above.\n",
        "res"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'v': 1,\n",
              " 'ts': '2024-07-10T11:34:28.485660+00:00',\n",
              " 'id': '1ef3eb05-e0d1-651b-8004-15f129f5f4fb',\n",
              " 'channel_values': {'__root__': 4, 'add_one': 'add_one'},\n",
              " 'channel_versions': {'__start__': 5,\n",
              "  '__root__': 6,\n",
              "  'start:add_one': 6,\n",
              "  'add_one': 6},\n",
              " 'versions_seen': {'__start__': {'__start__': 4},\n",
              "  'add_one': {'start:add_one': 5}},\n",
              " 'pending_sends': []}"
            ]
          },
          "execution_count": 9,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Load the checkpoint stored for this thread; config has no thread_ts.\n",
        "await checkpointer.aget(config)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {},
      "outputs": [
        {
          "data": {
            "text/plain": [
              "CheckpointTuple(config={'configurable': {'thread_id': '123'}}, checkpoint={'v': 1, 'ts': '2024-07-10T11:34:28.485660+00:00', 'id': '1ef3eb05-e0d1-651b-8004-15f129f5f4fb', 'channel_values': {'__root__': 4, 'add_one': 'add_one'}, 'channel_versions': {'__start__': 5, '__root__': 6, 'start:add_one': 6, 'add_one': 6}, 'versions_seen': {'__start__': {'__start__': 4}, 'add_one': {'start:add_one': 5}}, 'pending_sends': []}, metadata={'source': 'loop', 'step': 4, 'writes': {'add_one': 4}}, parent_config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3eb05-e0bd-6c9c-8003-aa9cb0fdedc1'}})"
            ]
          },
          "execution_count": 10,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Same lookup, returned as a full CheckpointTuple (config, checkpoint, metadata, parent config).\n",
        "await checkpointer.aget_tuple(config)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "CheckpointTuple(config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3eb05-e0d1-651b-8004-15f129f5f4fb'}}, checkpoint={'v': 1, 'ts': '2024-07-10T11:34:28.485660+00:00', 'id': '1ef3eb05-e0d1-651b-8004-15f129f5f4fb', 'channel_values': {'__root__': 4, 'add_one': 'add_one'}, 'channel_versions': {'__start__': 5, '__root__': 6, 'start:add_one': 6, 'add_one': 6}, 'versions_seen': {'__start__': {'__start__': 4}, 'add_one': {'start:add_one': 5}}, 'pending_sends': []}, metadata={'source': 'loop', 'step': 4, 'writes': {'add_one': 4}}, parent_config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3eb05-e0bd-6c9c-8003-aa9cb0fdedc1'}})\n",
            "CheckpointTuple(config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3eb05-e0bd-6c9c-8003-aa9cb0fdedc1'}}, checkpoint={'v': 1, 'ts': '2024-07-10T11:34:28.477660+00:00', 'id': '1ef3eb05-e0bd-6c9c-8003-aa9cb0fdedc1', 'channel_values': {'__root__': 3, 'start:add_one': '__start__'}, 'channel_versions': {'__start__': 5, '__root__': 5, 'start:add_one': 5, 'add_one': 4}, 'versions_seen': {'__start__': {'__start__': 4}, 'add_one': {'start:add_one': 2}}, 'pending_sends': []}, metadata={'source': 'loop', 'step': 3, 'writes': None}, parent_config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3eb05-e0bb-659e-8002-de83b4764141'}})\n",
            "CheckpointTuple(config={'configurable': {'thread_id': '123', 'thread_ts': '1ef3eb05-e0bb-659e-8002-de83b4764141'}}, checkpoint={'v': 1, 'ts': '2024-07-10T11:34:28.476662+00:00', 'id': '1ef3eb05-e0bb-659e-8002-de83b4764141', 'channel_values': {'__root__': 4, '__start__': 3}, 'channel_versions': {'__start__': 4, '__root__': 3, 'start:add_one': 3, 'add_one': 4}, 'versions_seen': {'__start__': {'__start__': 1}, 'add_one': {'start:add_one': 2}}, 'pending_sends': []}, metadata={'source': 'input', 'step': 2, 'writes': 3}, parent_config=None)\n"
          ]
        }
      ],
      "source": [
        "# Iterate the async generator directly: binding it to a variable named `list`\n",
        "# would shadow the built-in for the rest of the kernel session.\n",
        "async for checkpoint_tuple in checkpointer.alist(config, limit=3):\n",
        "    print(checkpoint_tuple)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Checkpoints saved in MongoDB"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "{'_id': ObjectId('668e57930f55bbe62f358531'), 'thread_id': '123', 'thread_ts': '1ef3ea0c-18a5-67a6-bfff-0d85b77e4a09', 'checkpoint': b'{\"v\": 1, \"ts\": \"2024-07-10T09:42:43.453328+00:00\", \"id\": \"1ef3ea0c-18a5-67a6-bfff-0d85b77e4a09\", \"channel_values\": {\"__start__\": 3}, \"channel_versions\": {\"__start__\": 1}, \"versions_seen\": {}, \"pending_sends\": []}', 'metadata': b'{\"source\": \"input\", \"step\": -1, \"writes\": 3}'}\n",
            "{'_id': ObjectId('668e57930f55bbe62f358532'), 'thread_id': '123', 'thread_ts': '1ef3ea0c-18a7-6ea3-8000-9a52ba553d0c', 'checkpoint': b'{\"v\": 1, \"ts\": \"2024-07-10T09:42:43.454326+00:00\", \"id\": \"1ef3ea0c-18a7-6ea3-8000-9a52ba553d0c\", \"channel_values\": {\"__root__\": 3, \"start:add_one\": \"__start__\"}, \"channel_versions\": {\"__start__\": 2, \"__root__\": 2, \"start:add_one\": 2}, \"versions_seen\": {\"__start__\": {\"__start__\": 1}, \"add_one\": {}}, \"pending_sends\": []}', 'metadata': b'{\"source\": \"loop\", \"step\": 0, \"writes\": null}', 'parent_ts': '1ef3ea0c-18a5-67a6-bfff-0d85b77e4a09'}\n",
            "{'_id': ObjectId('668e57930f55bbe62f358533'), 'thread_id': '123', 'thread_ts': '1ef3ea0c-18bc-6b54-8001-ef5781939492', 'checkpoint': b'{\"v\": 1, \"ts\": \"2024-07-10T09:42:43.462843+00:00\", \"id\": \"1ef3ea0c-18bc-6b54-8001-ef5781939492\", \"channel_values\": {\"__root__\": 4, \"add_one\": \"add_one\"}, \"channel_versions\": {\"__start__\": 2, \"__root__\": 3, \"start:add_one\": 3, \"add_one\": 3}, \"versions_seen\": {\"__start__\": {\"__start__\": 1}, \"add_one\": {\"start:add_one\": 2}}, \"pending_sends\": []}', 'metadata': b'{\"source\": \"loop\", \"step\": 1, \"writes\": {\"add_one\": 4}}', 'parent_ts': '1ef3ea0c-18a7-6ea3-8000-9a52ba553d0c'}\n",
            "{'_id': ObjectId('668e71c4171972a41a226373'), 'thread_id': '123', 'thread_ts': '1ef3eb05-e0bb-659e-8002-de83b4764141', 'checkpoint': b'{\"v\": 1, \"ts\": \"2024-07-10T11:34:28.476662+00:00\", \"id\": \"1ef3eb05-e0bb-659e-8002-de83b4764141\", \"channel_values\": {\"__root__\": 4, \"__start__\": 3}, \"channel_versions\": {\"__start__\": 4, \"__root__\": 3, \"start:add_one\": 3, \"add_one\": 4}, \"versions_seen\": {\"__start__\": {\"__start__\": 1}, \"add_one\": {\"start:add_one\": 2}}, \"pending_sends\": []}', 'metadata': b'{\"source\": \"input\", \"step\": 2, \"writes\": 3}'}\n",
            "{'_id': ObjectId('668e71c4171972a41a226374'), 'thread_id': '123', 'thread_ts': '1ef3eb05-e0bd-6c9c-8003-aa9cb0fdedc1', 'checkpoint': b'{\"v\": 1, \"ts\": \"2024-07-10T11:34:28.477660+00:00\", \"id\": \"1ef3eb05-e0bd-6c9c-8003-aa9cb0fdedc1\", \"channel_values\": {\"__root__\": 3, \"start:add_one\": \"__start__\"}, \"channel_versions\": {\"__start__\": 5, \"__root__\": 5, \"start:add_one\": 5, \"add_one\": 4}, \"versions_seen\": {\"__start__\": {\"__start__\": 4}, \"add_one\": {\"start:add_one\": 2}}, \"pending_sends\": []}', 'metadata': b'{\"source\": \"loop\", \"step\": 3, \"writes\": null}', 'parent_ts': '1ef3eb05-e0bb-659e-8002-de83b4764141'}\n",
            "{'_id': ObjectId('668e71c4171972a41a226375'), 'thread_id': '123', 'thread_ts': '1ef3eb05-e0d1-651b-8004-15f129f5f4fb', 'checkpoint': b'{\"v\": 1, \"ts\": \"2024-07-10T11:34:28.485660+00:00\", \"id\": \"1ef3eb05-e0d1-651b-8004-15f129f5f4fb\", \"channel_values\": {\"__root__\": 4, \"add_one\": \"add_one\"}, \"channel_versions\": {\"__start__\": 5, \"__root__\": 6, \"start:add_one\": 6, \"add_one\": 6}, \"versions_seen\": {\"__start__\": {\"__start__\": 4}, \"add_one\": {\"start:add_one\": 5}}, \"pending_sends\": []}', 'metadata': b'{\"source\": \"loop\", \"step\": 4, \"writes\": {\"add_one\": 4}}', 'parent_ts': '1ef3eb05-e0bd-6c9c-8003-aa9cb0fdedc1'}\n"
          ]
        }
      ],
      "source": [
        "from pymongo import MongoClient\n",
        "\n",
        "# Read the collection back with the synchronous client and print the raw\n",
        "# checkpoint documents it contains.\n",
        "client = MongoClient(MONGO_URI)\n",
        "database = client[\"checkpoints_db\"]\n",
        "collection = database[\"checkpoints_collection\"]\n",
        "\n",
        "for doc in collection.find():\n",
        "    print(doc)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "myenv",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
